clear all
%close all
addpath('sub_Table')
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Setting: Parameters
% Generate artificial data from a non-homothetic preference.
% 1: AIDS, 2: Nh-CES (flag for the type of preference)
paras.preference_type = 2;
paras.flag_sigma = 0; % 0 for constant sigma, 1 for variable sigma
N = 1000;
I = 3;% Number of goods; must equal numel(paras.eps_vec) (checked below).
T = 40; 
pols = [1,2,4,6,8,12];% JL polynomials
name = 'const_trend';
% The indirect utility of Nh-CES is calculated numerically from wages and prices.
% To make it easier to find a solution, the N households are split into Ns groups
% of N/Ns households, and each group is solved separately in every period t.
% Therefore, N must be a multiple of Ns.
Ns = 20;
% Nh-CES parameters
if paras.flag_sigma == 0
    paras.sig = 0.25;% from Comin et al. (2021)
else
    paras.sig = 10;
end
paras.eta = -2;% from Auer et al. (2021)
paras.eps_vec = [0.2;1;1.65];% Non-homothetic parametrization from Comin et al. (2021)
%paras.eps_vec = [1;1;1];% Homothetic
paras.T = T;paras.N = N;paras.I = I;paras.Ns = Ns;
% Fail early on inconsistent settings instead of deep inside the solvers.
assert(mod(N,Ns) == 0, 'N (%d) must be a multiple of Ns (%d).', N, Ns);
assert(numel(paras.eps_vec) == I, 'I (%d) must equal numel(paras.eps_vec) (%d).', I, numel(paras.eps_vec));

% parameters for AIDS
AIDSparas.alp = [1/3;1/3;1/3];
AIDSparas.alp0 = 2;
AIDSparas.bet = [-0.15;-0.05;0.2]./4;
AIDSparas.gamm(:,:,:,1) = -[1/4;-1/8;-1/8];
AIDSparas.gamm(:,:,:,2) = -[-1/8;1/4;-1/8];
AIDSparas.gamm(:,:,:,3) = -[-1/8;-1/8;1/4];

paras.Kn = 2;
%% Setting: Price and Wage
% Price Schedule: each good's price grows geometrically from 1 (exp(0)) in t = 1.
pvec(1,1,:) = reshape(exp(linspace(0,2,T)),[1,1,T]);% (I,1,T)
pvec(2,1,:) = reshape(exp(linspace(0,1.7,T)),[1,1,T]);% (I,1,T)
pvec(3,1,:) = reshape(exp(linspace(0,1.4,T)),[1,1,T]);% (I,1,T)


% Generate artificial wage 
% NOTE(review): trendU, trendL and trendVec are not used by the code in this file.
% They are kept because sub_err.m is run in this workspace and may read them.
trendU = (reshape(exp(linspace(0,log(10),T)),[1,1,T]));
trendL = (reshape(exp(linspace(0,log(10),T)),[1,1,T])); 

trendVec = zeros(1,N,T);
for t = 1:T
    trendVec(1,:,t) = linspace(trendL(:,:,t),trendU(:,:,t),N);
end
% Income trend: grows geometrically from 1 to 10 over the T periods.
trend = (reshape(exp(linspace(0,log(10),T)),[1,1,T]));
% Seed the generator right before the only random draw so the incomes are reproducible.
rng(1,'twister')
I_vec_tmp = lognrnd(3.58,0.7,1,N);
% Sort households by first-period income, then scale by the trend -> 1 x N x T.
I_vec = sort(I_vec_tmp,2).*trend;
paras.uvec_init = I_vec/10; % Initial value used to calculate the indirect utility of Comin.

% Calibration of Omega
% The share parameters are calibrated so that the budget shares of every good
% for the median household in the first period are all equal (1/I each).
medianw = median(I_vec(1,:,1));
% Initial guess built from paras.eps_vec (not a hard-coded copy), so it stays
% consistent with the preference parameters if they are changed above.
omega_tmp = 1./(medianw.^(paras.eps_vec*(1-paras.sig)));
omega_init = omega_tmp./sum(omega_tmp);
% Only the first I-1 shares are free; the last is 1 - sum of the others.
omega_init2 = omega_init(1:end-1,:);
options = optimoptions('fsolve', 'Algorithm', 'trust-region', 'TolFun', 1e-7,'Display','off');
[omega_tmp2,~,omega_exitflag] = fsolve( @(x) CalOmega(x,medianw,paras),omega_init2,options);
if omega_exitflag <= 0
    warning('Omega calibration: fsolve exit flag %d; budget shares may not equal 1/I.', omega_exitflag);
end
paras.omega = [omega_tmp2;1-sum(omega_tmp2)];
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Generating Artificial data %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[v_vec, u_vec, B_vec,flag]= NhCESmax(I_vec,pvec,paras) ;
% flag is true when a solver call inside NhCESmax returned a negative exit flag.
% Warn instead of silently continuing with utilities that may be unsolved.
if any(flag(:))
    warning('NhCESmax: at least one fsolve/fmincon call returned a negative exit flag.');
end
%%BBK Algorithm %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

[U_vec_FP] = CalMoneyMetric(I_vec, B_vec, pvec,1);
[U_vec] = CalMoneyMetric(I_vec, B_vec, pvec,0);
%%Run JL with loop: one JL estimation per polynomial order in pols
format shortG

%figure('Position',[100 100 800 400],'NumberTitle','off')
for i = 1:size(pols,2)
pol= pols(i);    
paras.Kn = pol;
warning("off")
 lastwarn('');
[LJMM_F,LJMM_SC,JLerrorType] = CalLJ_poly(I_vec, B_vec, pvec,paras);

% sub_err.m runs as a script in this workspace. ErrorJL (used below) and
% ErrorBBK (used after the loop) are not defined in this file; presumably
% sub_err.m sets them.
run("sub_err.m")

ErrorJLVec(i,:) = ErrorJL;

JLErrorTypeVec(i,:) = JLerrorType; 

end

%% produce tables
caption_text = 'JL Algorithm';
% No semicolon: the generated table is echoed to the command window.
latex_table = create_latex_table_JL(ErrorJLVec,pols, JLErrorTypeVec, caption_text)
%save_latex_table_to_file(latex_table, ['../fig/' name '_JL.tex']);
caption_text = 'BBK Algorithm';
latex_table = create_latex_table_1x4(ErrorBBK);
%save_latex_table_to_file(latex_table, ['../fig/' name '_BBK.tex']);

function [v_vec, u_vec, B_vec,flagok]= NhCESmax(I_vec,pvec,paras) 
%NHCESMAX Nonhomothetic preference: given income W and prices P, solve for utility u.
%   [v_vec, u_vec, B_vec, flagok] = NhCESmax(I_vec, pvec, paras)
%   For every period t and every group of N/Ns households, find v such that
%   the Nh-CES expenditure function at prices pvec(:,:,t) equals I_vec(:,Index,t).
%   fsolve is tried first. If it returns a negative exit flag, fmincon is used
%   to minimise the norm of the residual instead.
%
%   Inputs
%     I_vec - income, indexed as I_vec(:,n,t) (1 x N x T in the calling script)
%     pvec  - prices, indexed as pvec(:,:,t) (I x 1 x T in the calling script)
%     paras - struct with fields T, N, Ns, uvec_init, omega, eps_vec, sig, flag_sigma
%   Outputs
%     v_vec  - solved utility levels, 1 x N x T
%     u_vec  - expenditure needed at base-period prices pvec(:,:,1) to reach v_vec
%     B_vec  - budget shares at current prices pvec(:,:,t)
%     flagok - true if ANY stored solver exit flag is negative, i.e. a failure
%              (despite the name)
T = paras.T;
N = paras.N;
Ns = paras.Ns;
if mod(N, Ns) ~= 0
    error('NhCESmax:badNs', 'paras.N (%d) must be a multiple of paras.Ns (%d).', N, Ns);
end
groupSize = N/Ns;

v_vec = zeros(1,N,T);
flag = zeros(1,N,T);
uvec_init = paras.uvec_init;
% The solver options do not change inside the loops, so build them once.
fsolveOpts = optimoptions('fsolve', 'TolFun', 1e-8, 'Display', 'off', 'Algorithm', 'trust-region');%,'Algorithm','Levenberg-Marquardt');
fminconOpts = optimoptions('fmincon','Algorithm','interior-point','OptimalityTolerance',1e-8);
for t = 1:T
    for ns = 1:Ns
        Index = groupSize*(ns-1)+1:groupSize*ns;
        % NOTE(review): two-subscript indexing means that when uvec_init is
        % 1 x N x T (as the calling script sets it), this always takes the
        % period-1 slice as the starting point, for every t. Confirm this is intended.
        x0 = uvec_init(:,Index);
        [v_vec(:,Index,t),~,exitflag] = fsolve( @(x) NHU(x,I_vec(:,Index,t),pvec(:,:,t),paras),x0,fsolveOpts);
        flag(:,Index,t) = exitflag;

        % Negative exit flag = fsolve failed; retry as a minimisation problem.
        if exitflag < 0
            [v_vec(:,Index,t),~,exitflag] = fmincon(@(x) norm(NHU(x,I_vec(:,Index,t),pvec(:,:,t),paras)), x0, [], [], [], [],[],[], [], fminconOpts);
            flag(:,Index,t) = exitflag;
        end
    end 
end

flagok = any(flag(:) < 0);
%%Obtain EV
[u_vec,B_vec]= CalEV(v_vec,pvec,paras);

    function err = NHU(v_vec,I_vec,pvec,paras) 
        % Elementwise absolute residual |E(v, p) - I| of the Nh-CES expenditure function.
        if paras.flag_sigma ==1
            %paras.sig = max(min(sig0+log((v_vec).^paras.eta),5),1.5); % functional form from Auer (2021)
            paras.sig = variableSigma(v_vec);
        end
        E = sum(paras.omega.*(((v_vec.^paras.eps_vec).*pvec).^(1-paras.sig))).^(1./(1-paras.sig));
        err = abs(E-I_vec);
    end

    function [u_vec,B_vec]= CalEV(v_vec,pvec,paras)
        % u_vec: expenditure at base-period prices pvec(:,:,1) for utility v_vec.
        % B_vec: budget shares omega*(v^eps*p)^(1-sig), normalised over goods, at prices pvec.
        if paras.flag_sigma ==1
            %    paras.sig = max(min(sig0+log((v_vec).^paras.eta),5),1.5); % functional form from Auer (2021)
            paras.sig = variableSigma(v_vec);
        end
        pvec_b = pvec(:,:,1);
        u_vec = sum(paras.omega.*(((v_vec.^paras.eps_vec).*pvec_b).^(1-paras.sig))).^(1./(1-paras.sig));
        tmp = paras.omega.*(((v_vec.^paras.eps_vec).*pvec).^(1-paras.sig));
        B_vec = tmp./sum(tmp);
    end

    function sig = variableSigma(v)
        % Utility-dependent sigma used when paras.flag_sigma == 1.
        % Power-mean smoothing: exponent (xi-1)/xi = -99 acts as a soft min and
        % exponent xi2-1 = 99 as a soft max, approximating
        % max(1.5, min(5, 10 - 2*log(v))).
        xi = 0.01;
        xi2= 100;
        sig = ((1.5)^((xi2-1))+((5^((xi-1)/xi)+(10-2*log(v)).^((xi-1)/xi)).^(xi/(xi-1))).^((xi2-1))).^(1/(xi2-1));
    end

end

function  err = CalOmega(omega_init,I_vec,paras) 
%CALOMEGA Residuals for calibrating the Nh-CES share parameters omega.
%   err = CalOmega(omega_init, I_vec, paras) completes the share vector as
%   [omega_init; 1 - sum(omega_init)]. It then solves for the utility v at which
%   the Nh-CES expenditure function, with every price equal to 1, equals I_vec.
%   Finally it returns |B - 1/paras.I| for each good's budget share B at that v.
%   Required paras fields: eps_vec, sig, I.
omegaFull = [omega_init; 1-sum(omega_init)];
oneMinusSig = 1-paras.sig;

% Expenditure function at unit prices, and the residual that fsolve drives to zero.
spend = @(v) sum(omegaFull.*((v.^paras.eps_vec).^oneMinusSig)).^(1./oneMinusSig);
gap = @(v) abs(spend(v) - I_vec);

solverOpts = optimoptions('fsolve', 'Algorithm', 'trust-region', 'TolFun', 1e-5, 'Display', 'off');
vSol = fsolve(gap, I_vec/10, solverOpts);

% Budget shares at the solved utility, compared with the equal-share target 1/I.
weights = omegaFull.*((vSol.^paras.eps_vec).^oneMinusSig);
shares = weights./sum(weights);
err = abs(shares(:,:) - 1/paras.I);
end